{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/cpu/lib/python3.6/site-packages/gym/core.py:26: UserWarning: \u001b[33mWARN: Gym minimally supports python 3.6 as the python foundation not longer supports the version, please update your version to 3.7+\u001b[0m\n",
      "  \"Gym minimally supports python 3.6 as the python foundation not longer supports the version, please update your version to 3.7+\"\n",
      "/root/anaconda3/envs/cpu/lib/python3.6/site-packages/gym/core.py:330: DeprecationWarning: \u001b[33mWARN: Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\u001b[0m\n",
      "  \"Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\"\n",
      "/root/anaconda3/envs/cpu/lib/python3.6/site-packages/gym/wrappers/step_api_compatibility.py:40: DeprecationWarning: \u001b[33mWARN: Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\u001b[0m\n",
      "  \"Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\"\n",
      "/root/anaconda3/envs/cpu/lib/python3.6/site-packages/gym/core.py:52: DeprecationWarning: \u001b[33mWARN: The argument mode in render method is deprecated; use render_mode during environment initialization instead.\n",
      "See here for more information: https://www.gymlibrary.ml/content/api/\u001b[0m\n",
      "  \"The argument mode in render method is deprecated; \"\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQEAAAD8CAYAAAB3lxGOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAARnUlEQVR4nO3da2xU953G8e9jY5sYh3CzCcUkuKmjJpA2F5fSNt02aWjYJC15U5VKWfEiUqQqK7XaSl3YSl1V2kjZXanqq7yI2qqo7RZRpVUQqtpFbtL0kgAmUK4hOKUBBwcbCMQkgG+/feETdgoOPtgznnH/z0ey5sx//jN+RjFPzjlz5hxFBGaWrqpyBzCz8nIJmCXOJWCWOJeAWeJcAmaJcwmYJa5kJSBppaSDkjolrS3V7zGziVEpjhOQVA28CqwAuoDtwFciYn/Rf5mZTUip1gSWAZ0R8ZeI6Ac2AKtK9LvMbAKmleh1FwJHC+53AR9/v8nz5s2LxYsXlyiKmQHs2LHjREQ0XjpeqhLQKGN/s90h6THgMYAbbriBjo6OEkUxMwBJr482XqrNgS5gUcH9ZuBY4YSIeDoi2iKirbHxsnIys0lSqhLYDrRKapFUC6wGNpXod5nZBJRkcyAiBiX9M/AboBr4YUTsK8XvMrOJKdU+ASLiV8CvSvX6ZlYcPmLQLHEuAbPEuQTMEucSMEucS8AscS4Bs8S5BMwS5xIwS5xLwCxxLgGzxLkEzBLnEjBLnEvALHEuAbPEuQTMEucSMEucS8AscS4Bs8S5BMwS5xIwS5xLwCxxLgGzxLkEzBLnEjBLnEvALHEuAbPEuQTMEucSMEucS8AscS4Bs8S5BMwS5xIwS5xLwCxxLgGzxLkEzBI3ZglI+qGkHkl7C8bmSNoi6VB2O7vgsXWSOiUdlHR/qYKbWXHkWRP4EbDykrG1QHtEtALt2X0k3QqsBpZkz3lKUnXR0ppZ0Y1ZAhHxAnDqkuFVwPpseT3wcMH4hoi4EBGHgU5gWXGimlkpjHefwPyI6AbIbpuy8YXA0YJ5XdnYZSQ9JqlDUkdvb+84Y5jZRBV7x6BGGYvRJkbE0xHRFhFtjY2NRY5hZnmNtwSOS1oAkN32ZONdwKKCec3AsfHHM7NSG28JbALWZMtrgGcLxldLqpPUArQC2yYW0cxKadpYEyT9DPgsME9SF/DvwJPARkmPAkeALwFExD5JG4H9wCDweEQMlSi7mRXBmCUQEV95n4c+9z7znwCemEgoM5s8PmLQLHEuAbPEuQTMEucSMEucS8AscS4Bs8S5BMwS5xIwS5xLwCxxLgGzxLkEzBLnEjBLnEvALHEuAbPEuQTMEucSMEucS8AscS4Bs8S5BMwS5xIwS5xLwCxxLgGzxLkEzBLnEjBLnEvALHEuAbPEuQTMEucSMEucS8AscS4Bs8S5BMwS5xIwS5xLwCxxLgGzxI1ZApIWSXpO0gFJ+yR9LRufI2mLpEPZ7eyC56yT1CnpoKT7S/kGzGxi8qwJDALfiIhbgOXA45JuBdYC7RHRCrRn98keWw0sAVYCT0mqLkV4M5u4MUsgIroj4uVsuQ84ACwEVgHrs2nrgYez5VXAhoi4EBGHgU5gWZFzm1mRXNU+AUmLgTuArcD8iOiGkaIAmrJpC4GjBU/rysbMrALlLgFJDcAzwNcj4u0rTR1lLEZ5vcckdUjq6O3tzRvDzIosVwlIqmGkAH4aEb/Iho9LWpA9vgDoyca7gEUFT28Gjl36mhHxdES0RURbY2PjePOb2QTl+XRAwA+AAxHx3YKHNgFrsuU1wLMF46sl1UlqAVqBbcWLbGbFNC3HnE8B/wTskbQrG/s34Elgo6RHgSPAlwAiYp+kjcB+Rj5ZeDwihood3MyKY8wSiIg/MPp2PsDn3uc5TwBPTCCXmU0SHzFoljiXgFniXAJmiXMJmCXOJWCWOJeAWeJcAmaJcwmYJc4lYJY4l4BZ4vJ8d8AMgIhg6OxZLhw/TgwOUnPdddQ2NUFVFSPfM7OpyCVguQz393Pyuefo2byZC93dxOAg02bO5Lq77mLBl79M7fz5LoIpyiVgYxoeHOTNX/6SN3/+c6K//+L44OnTnGxv53xXFy3f+IaLYIryPgG7ooigb/dujj/zDNHfz9v9/fyos5P/3ruX3adOERG8c/Agb/z4xzA8XO64Ng5eE7Ax9W7ezPD58/QNDPDtnTv5Q8/ISaR+3dXFf9x5J59oauLMtm2ce/116j/4wTKntavlNQG7sggGzpwB4I133+VPPT0XHzozMMD/Hhs5c9zwhQsMnj1blog2MS4By622qoq66r+9hMTMmpqLy2e2byfisnPKWoVzCdjYsp19LQ0N/OtttzGvro666mruXbCAR1tbL07r27WLIa8NTDneJ2BXJnFdWxvvvvoqkniwuZk7587l3OAgC2fMYHrBmsG5ri769uxh9ic/WcbAdrW8JmBXJImZt99O1TXXXLz/gfp6bpo5828KAIChoZFNgiGfV3YqcQnYmOpbWqi7/vpcc8+8/DKDfX0lTmTF5BKwMam2lnn33Xdx38CVDJ45w9s7d3oH4RTiErAxSWJmWxvV9fVjTx4e5vTWrT5waApxCVgutY2NNCxdmmvuO6+8Qv/JkyVOZMXiErBcVF3N7OXLc20SDJw6xemXXvImwRThErBcJNGwZAnTZs7MNf/0iy8Sg4MlTmXF4BKw3Grnz2fGzTfnmvvOoUOc7+oqcSIrBpeAXZV5K1bApccHjCL6+71JMEW4BCw3STQsXUpdU1Ou+W/v2sXw+fMlTmUT5RKwq1JdX891H/tYrrnnDh/2JsEU4BKwq5N9l0C1tWNOHT5/ntNbt3qToMK5BOyqSKK+tZW6xsZc889s28bwuXMlTmUT4RKwq1ZdX8/Mu+7KNffc0aP07dnjtYEK5hKwqyaJ2XffjXJ8SsDQEKdfeqn0oWzcXAI2LtMXLeKaxYtzzT27fz+D2SnKrPKMWQKSpkvaJunPkvZJ+k42PkfSFkmHstvZBc9ZJ6lT0kFJ95fyDVh5VNfXM+eee3LNvdDdzTuvvlriRDZeedYELgD3RsRHgduBlZKWA2uB9ohoBdqz+0i6FVgNLAFWAk9JyrHeaFOJJGZ+9KNUNzTkmn/qhRd8spEKNWYJxIj3ThxXk/0EsApYn42vBx7OllcBGyLiQkQcBjqBZcUMbZVhenNz/k2CvXsZOHWqtIFsXHLtE5BULWkX0ANsiYitwPyI6AbIbt87jGwhcLTg6V3Z2KWv+ZikDkkdvb29E3gLVjZVVcz59KdzTR146y3e+tOf/ClBBcpVAhExFBG3A83AMklX+mL5aN81vey/fEQ8HRFtEdHWmPMzZ6sskph5xx35vlkYwZkdO4iBgdIHs6tyVZ8ORMRp4HlGtvWPS1oAkN2+d1WKLmBRwdOagWMTDWqVqWbuXGZ8+MO55r772mv0F1y8xCpDnk8HGiXNypavAe4DXgE2AWuyaWuAZ7PlTcBqSXWSWoBWYFuRc1uF0LRpzL3nnlzHDAz19fn8gxUoz5rAAuA5SbuB7YzsE9gMPAmskHQIWJHdJyL2ARuB/cCvgccjwruF/05JYsYtt1Azd26u+ad+/3ufbKTCjHnxkYjYDdwxyvhJ4HPv85wngCcmnM6mhJpZs7j2tts42d4+5txzR45w/sgR6m+6aRKSWR4+YtAmTFVVzPnMZ3Kdf3D43Xc50d7uTYIK4hKwoqi/6SamL7zsk+BRnd2zx9csrCAuASuK6oYGGm67Ldfcc11dnHv99RInsrxcAlY0c++55+I1C69oaIi3/vhHbxJUCJeAFYUk6ltacm8SnNmxgyFfs7AiuASsaFRby6zly3PNHTh5krOvvOK1gQrgErCikcSsj3881yZBDAxwZvt2cAmUnUvAiqp2/nxmfOhDueae6eig/8SJEieysbgErKiq6uqYecdlx5aNauDkSc7u3etNgjJzCVhRSeK6ZcuoqqvLNf/U737nTYIycwlY0U3/wAe49iMfyTX3nc5OLnR3lziRXYlLwIpO06aNHEacw1BfH327d3uToIxcAlYSDUuXUpvzZDGnt271yUbKyCVgJVEzaxYNS5bkmnv2wAHOHz069kQrCZeAlYY0cuHSHCcbGT53jrdefNGbBGXiErCSeO+U5DWzZuWaf/qll3wZ8zJxCVjJVDc0MOsTn8g198Kbb3L2wAGvDZSBS8BKRlVVzLn7bjRtzBNYEf39nPjNbyYhlV3KJWAldc2NN+a+QEn/iRM+cKgMXAJWUlX19bkPI7bycAlYSUli3uc/n+tsxDNuvjnXeQqtuFwCVnK1TU00feELV9w3ULdgAU0PPjiJqew9LgErOUk0PvAAjQ89hGprL3u89vrrueGrX6Vu4ULkNYFJN/ZuW7MiqJ4+nYWPPELDLbdw6vnnOf/GG1TV1XHt0qXMW7HCBVBGLgGbNFXZ6cdmLVtGDA8DjFy+THIBlJFLwCaVJKiuznXtQpsc3idgljiXgFniXAJmiXMJmCXOJWCWOJeAWeJcAmaJcwmYJS53CUiqlrRT0ubs/hxJWyQdym5nF8xdJ6lT0kFJ95ciuJkVx9WsCXwNOFBwfy3QHhGtQHt2H0m3AquBJcBK4ClJPjzMrELlKgFJzcCDwPcLhlcB67Pl9cDDBeMbIuJCRBwGOoFlRUlrZkWXd03ge8A3geGCsfkR0Q2Q3TZl4wuBwpPId2VjZlaBxiwBSQ8BPRGxI+drjvZ1sMtOHCfpMUkdkjp6e3tzvrSZFVueNYFPAV+U9FdgA3CvpJ8AxyUtAMhue7L5XcCiguc3A8cufdGIeDoi2iKirTHn5arMrPjGLIGIWBcRzRGxmJEdfr+NiEeATcCabNoa4NlseROwWlKdpBagFdhW9ORmVhQTOZ/Ak8BGSY8CR4AvAUTEPkkbgf3AIPB4RAxNOKmZlYQq4YovbW1t0dHRUe4YZn/XJO2IiLZLx33EoFniXAJmiXMJmCXOJWCWOJeAWeJcAmaJcwmYJc4lYJY4l4BZ4lwCZolzCZglziVgljiXgFniXAJmiXMJmCXOJWCWOJeAWeJcAmaJcwmYJc4lYJY4l4BZ4lwCZolzCZglziVgljiXgFniXAJmiXMJmCXOJWCWOJeAWeJcAmaJcwmYJc4lYJY4l4BZ4lwCZolzCZglziVgljiXgFniXAJmiVNElDsDknqBd4AT5c6S0zymTlaYWnmdtXRujIjGSwcrogQAJHVERFu5c+QxlbLC1MrrrJPPmwNmiXMJmCWukkrg6XIHuApTKStMrbzOOskqZp+AmZVHJa0JmFkZlL0EJK2UdFBSp6S15c4DIOmHknok7S0YmyNpi6RD2e3sgsfWZfkPSrp/krMukvScpAOS9kn6WqXmlTRd0jZJf86yfqdSsxb8/mpJOyVtrvSs4xYRZfsBqoHXgA8CtcCfgVvLmSnL9Q/AncDegrH/AtZmy2uB/8yWb81y1wEt2fupnsSsC4A7s+VrgVezTBWXFxDQkC3XAFuB5ZWYtSDzvwD/A2yu5L+DifyUe01gGdAZEX+JiH5gA7CqzJmIiBeAU5cMrwLWZ8vrgYcLxjdExIWIOAx0MvK+JkVEdEfEy9lyH3AAWFiJeWPE2exuTfYTlZgVQFIz8CDw/YLhisw6EeUugYXA0YL7XdlYJZofEd0w8g8PaMrGK+Y9SFoM3MHI/2ErMm+2er0L6AG2RETFZgW+B3wTGC4Yq9Ss41buEtAoY1Pt44qKeA+SGoBngK9HxNtXmjrK2KTljYihiLgdaAaWSVp6hellyyrpIaAnInbkfcooY1Pib7ncJdAFLCq43wwcK1OWsRyXtAAgu+3Jxsv+HiTVMFIAP42IX2TDFZsXICJOA88DK6nMrJ8Cvijpr4xspt4r6ScVmnVCyl0C24FWSS2SaoHVwKYyZ3o/m4A12fIa4NmC8dWS6iS1AK3AtskKJUnAD4ADEfHdSs4rqVHSrGz5GuA+4JVKzBoR6yKiOSIWM/J3+duIeKQSs05YufdMAg8wskf7NeBb5c6TZfoZ0A0MMNLwjwJzgXbgUHY7p2D+t7L8B4F/nOSsdzOy2rkb2JX9PFCJeYGPADuzrHuBb2fjFZf1ktyf5f8/HajorOP58RGDZokr9+aAmZWZS8AscS4Bs8S5BMwS5xIwS5xLwCxxLgGzxLkEzBL3f0gk6yn8a/gIAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import gym\n",
    "from matplotlib import pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "#创建环境\n",
    "env = gym.make('Pendulum-v1')\n",
    "env.reset()\n",
    "\n",
    "\n",
    "#打印游戏\n",
    "def show():\n",
    "    plt.imshow(env.render(mode='rgb_array'))\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([5, 11])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "#DuelingDQN和其他DQN模型不同的点,它使用的是不同的模型结构\n",
    "class VAnet(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        self.fc = torch.nn.Sequential(\n",
    "            torch.nn.Linear(3, 128),\n",
    "            torch.nn.ReLU(),\n",
    "        )\n",
    "\n",
    "        self.fc_A = torch.nn.Linear(128, 11)\n",
    "        self.fc_V = torch.nn.Linear(128, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        #[5, 11] -> [5, 128] -> [5, 11]\n",
    "        A = self.fc_A(self.fc(x))\n",
    "\n",
    "        #[5, 11] -> [5, 128] -> [5, 1]\n",
    "        V = self.fc_V(self.fc(x))\n",
    "\n",
    "        #[5, 11] -> [5] -> [5, 1]\n",
    "        A_mean = A.mean(dim=1).reshape(-1, 1)\n",
    "\n",
    "        #[5, 11] - [5, 1] = [5, 11]\n",
    "        A -= A_mean\n",
    "\n",
    "        #Q值由V值和A值计算得到\n",
    "        #[5, 11] + [5, 1] = [5, 11]\n",
    "        Q = A + V\n",
    "\n",
    "        return Q\n",
    "\n",
    "\n",
    "VAnet()(torch.randn(5, 3)).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(VAnet(\n",
       "   (fc): Sequential(\n",
       "     (0): Linear(in_features=3, out_features=128, bias=True)\n",
       "     (1): ReLU()\n",
       "   )\n",
       "   (fc_A): Linear(in_features=128, out_features=11, bias=True)\n",
       "   (fc_V): Linear(in_features=128, out_features=1, bias=True)\n",
       " ),\n",
       " VAnet(\n",
       "   (fc): Sequential(\n",
       "     (0): Linear(in_features=3, out_features=128, bias=True)\n",
       "     (1): ReLU()\n",
       "   )\n",
       "   (fc_A): Linear(in_features=128, out_features=11, bias=True)\n",
       "   (fc_V): Linear(in_features=128, out_features=1, bias=True)\n",
       " ))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "#计算动作的模型,也是真正要用的模型\n",
    "model = VAnet()\n",
    "\n",
    "#经验网络,用于评估一个状态的分数\n",
    "next_model = VAnet()\n",
    "\n",
    "#把model的参数复制给next_model\n",
    "next_model.load_state_dict(model.state_dict())\n",
    "\n",
    "model, next_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0, -2.0)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "\n",
    "def get_action(state):\n",
    "    #走神经网络,得到一个动作\n",
    "    state = torch.FloatTensor(state).reshape(1, 3)\n",
    "    action = model(state).argmax().item()\n",
    "\n",
    "    if random.random() < 0.01:\n",
    "        action = random.choice(range(11))\n",
    "\n",
    "    #离散动作连续化\n",
    "    action_continuous = action\n",
    "    action_continuous /= 10\n",
    "    action_continuous *= 4\n",
    "    action_continuous -= 2\n",
    "\n",
    "    return action, action_continuous\n",
    "\n",
    "\n",
    "get_action([0.29292667, 0.9561349, 1.0957013])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((200, 0), 200)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#样本池\n",
    "datas = []\n",
    "\n",
    "\n",
    "#向样本池中添加N条数据,删除M条最古老的数据\n",
    "def update_data():\n",
    "    old_count = len(datas)\n",
    "\n",
    "    #玩到新增了N个数据为止\n",
    "    while len(datas) - old_count < 200:\n",
    "        #初始化游戏\n",
    "        state = env.reset()\n",
    "\n",
    "        #玩到游戏结束为止\n",
    "        over = False\n",
    "        while not over:\n",
    "            #根据当前状态得到一个动作\n",
    "            action, action_continuous = get_action(state)\n",
    "\n",
    "            #执行动作,得到反馈\n",
    "            next_state, reward, over, _ = env.step([action_continuous])\n",
    "\n",
    "            #记录数据样本\n",
    "            datas.append((state, action, reward, next_state, over))\n",
    "\n",
    "            #更新游戏状态,开始下一个动作\n",
    "            state = next_state\n",
    "\n",
    "    update_count = len(datas) - old_count\n",
    "    drop_count = max(len(datas) - 5000, 0)\n",
    "\n",
    "    #数据上限,超出时从最古老的开始删除\n",
    "    while len(datas) > 5000:\n",
    "        datas.pop(0)\n",
    "\n",
    "    return update_count, drop_count\n",
    "\n",
    "\n",
    "update_data(), len(datas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/cpu/lib/python3.6/site-packages/ipykernel_launcher.py:7: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at  /opt/conda/conda-bld/pytorch_1640811701593/work/torch/csrc/utils/tensor_new.cpp:201.)\n",
      "  import sys\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor([[ 2.0041e-01,  9.7971e-01, -2.8199e+00],\n",
       "         [ 4.7866e-01,  8.7800e-01, -1.5885e+00],\n",
       "         [-6.3839e-01,  7.6972e-01,  4.7666e+00],\n",
       "         [-9.0332e-01,  4.2897e-01, -5.4418e+00],\n",
       "         [-4.6531e-01,  8.8515e-01, -5.0707e+00],\n",
       "         [-4.0830e-01, -9.1285e-01, -1.2375e+00],\n",
       "         [ 7.0601e-01, -7.0821e-01, -2.5809e-01],\n",
       "         [-5.5123e-01, -8.3436e-01,  1.8847e+00],\n",
       "         [-3.5083e-01, -9.3644e-01,  6.9238e-01],\n",
       "         [ 3.6800e-01,  9.2983e-01,  1.9958e+00],\n",
       "         [ 4.1285e-01,  9.1080e-01,  2.2697e+00],\n",
       "         [ 6.0357e-01,  7.9731e-01, -2.7398e-01],\n",
       "         [ 4.0749e-01,  9.1321e-01, -1.9734e+00],\n",
       "         [-1.0000e+00, -1.7916e-03, -5.6457e+00],\n",
       "         [-6.9007e-01, -7.2374e-01,  3.0047e+00],\n",
       "         [-5.0700e-01, -8.6195e-01, -2.2222e+00],\n",
       "         [ 5.4799e-01, -8.3649e-01, -1.6400e+00],\n",
       "         [-7.4000e-01, -6.7261e-01, -2.8528e+00],\n",
       "         [ 5.8329e-01,  8.1226e-01,  5.0400e-01],\n",
       "         [-3.4004e-01, -9.4041e-01,  2.3005e-01],\n",
       "         [-4.6885e-02,  9.9890e-01,  3.2634e+00],\n",
       "         [ 5.0131e-01,  8.6527e-01, -1.7845e+00],\n",
       "         [-9.8537e-01,  1.7042e-01, -5.2696e+00],\n",
       "         [-6.5946e-01, -7.5174e-01, -2.5096e+00],\n",
       "         [-5.8631e-01, -8.1009e-01,  2.7019e+00],\n",
       "         [ 5.3360e-01,  8.4573e-01, -1.6032e+00],\n",
       "         [ 5.8376e-01,  8.1193e-01, -1.7884e+00],\n",
       "         [-7.5552e-01,  6.5513e-01, -5.4201e+00],\n",
       "         [ 6.1466e-01, -7.8879e-01, -1.3484e+00],\n",
       "         [-9.7245e-01, -2.3311e-01, -4.7890e+00],\n",
       "         [ 6.5317e-01,  7.5721e-01, -3.6269e-01],\n",
       "         [-3.4898e-01, -9.3713e-01, -1.4615e+00],\n",
       "         [-9.9998e-01, -6.0766e-03,  5.1946e+00],\n",
       "         [ 2.5437e-01,  9.6711e-01,  2.3932e+00],\n",
       "         [ 6.6569e-01,  7.4623e-01, -4.0571e-01],\n",
       "         [ 7.1509e-01, -6.9904e-01,  5.6619e-01],\n",
       "         [-6.5161e-01,  7.5855e-01, -5.5029e+00],\n",
       "         [ 7.0133e-02, -9.9754e-01, -3.1632e+00],\n",
       "         [ 6.7597e-01,  7.3693e-01,  1.6478e+00],\n",
       "         [-9.5383e-01,  3.0036e-01, -5.5429e+00],\n",
       "         [ 7.4633e-01,  6.6557e-01, -7.4006e-01],\n",
       "         [ 3.7493e-01,  9.2705e-01, -2.3628e+00],\n",
       "         [-9.3451e-01, -3.5593e-01,  4.5318e+00],\n",
       "         [-3.5107e-01, -9.3635e-01, -2.3526e-01],\n",
       "         [-5.7678e-01, -8.1690e-01,  2.1447e+00],\n",
       "         [ 5.8615e-01,  8.1020e-01, -1.2689e+00],\n",
       "         [ 7.2120e-01,  6.9273e-01, -9.5961e-01],\n",
       "         [-7.4823e-01, -6.6344e-01,  3.2651e+00],\n",
       "         [-6.6090e-01, -7.5048e-01,  2.4675e+00],\n",
       "         [-8.8942e-01, -4.5709e-01, -4.1462e+00],\n",
       "         [ 6.0740e-02,  9.9815e-01, -3.5685e+00],\n",
       "         [-9.6877e-01, -2.4797e-01,  4.8901e+00],\n",
       "         [-9.8608e-01,  1.6627e-01,  4.9765e+00],\n",
       "         [ 6.5042e-01,  7.5958e-01, -6.7539e-01],\n",
       "         [ 1.1601e-01,  9.9325e-01,  2.8185e+00],\n",
       "         [ 4.5880e-01,  8.8854e-01,  1.6294e+00],\n",
       "         [-5.6051e-01,  8.2814e-01, -5.2287e+00],\n",
       "         [-2.3044e-01,  9.7309e-01,  3.7126e+00],\n",
       "         [-9.5992e-01, -2.8027e-01, -5.1355e+00],\n",
       "         [ 5.3529e-01,  8.4467e-01,  1.8555e+00],\n",
       "         [-7.7324e-01,  6.3412e-01,  4.9930e+00],\n",
       "         [-3.4641e-02,  9.9940e-01, -4.2799e+00],\n",
       "         [-2.5837e-01, -9.6604e-01,  4.3088e-02],\n",
       "         [-9.1477e-01,  4.0398e-01,  4.9735e+00]]),\n",
       " tensor([[ 0],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 9],\n",
       "         [ 9],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 8],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 9],\n",
       "         [10],\n",
       "         [10],\n",
       "         [ 0],\n",
       "         [ 8],\n",
       "         [ 9],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 9],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [10],\n",
       "         [ 0],\n",
       "         [ 8],\n",
       "         [10],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 8],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [10],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 9],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 9],\n",
       "         [ 0],\n",
       "         [ 5],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [ 0],\n",
       "         [10],\n",
       "         [ 9],\n",
       "         [ 0]]),\n",
       " tensor([[ -2.6734],\n",
       "         [ -1.4048],\n",
       "         [ -7.3981],\n",
       "         [-10.2458],\n",
       "         [ -6.7973],\n",
       "         [ -4.1228],\n",
       "         [ -0.6300],\n",
       "         [ -5.0002],\n",
       "         [ -3.7725],\n",
       "         [ -1.8278],\n",
       "         [ -1.8307],\n",
       "         [ -0.8606],\n",
       "         [ -1.7185],\n",
       "         [-13.0498],\n",
       "         [ -6.3454],\n",
       "         [ -4.9183],\n",
       "         [ -1.2547],\n",
       "         [ -6.5964],\n",
       "         [ -0.9256],\n",
       "         [ -3.6856],\n",
       "         [ -3.6859],\n",
       "         [ -1.4159],\n",
       "         [-11.6038],\n",
       "         [ -5.8820],\n",
       "         [ -5.5606],\n",
       "         [ -1.2770],\n",
       "         [ -1.2215],\n",
       "         [ -8.8332],\n",
       "         [ -1.0118],\n",
       "         [-10.7442],\n",
       "         [ -0.7525],\n",
       "         [ -3.9320],\n",
       "         [-12.5339],\n",
       "         [ -2.3023],\n",
       "         [ -0.7275],\n",
       "         [ -0.6352],\n",
       "         [ -8.2329],\n",
       "         [ -3.2564],\n",
       "         [ -0.9620],\n",
       "         [-11.1223],\n",
       "         [ -0.5891],\n",
       "         [ -1.9700],\n",
       "         [ -9.7732],\n",
       "         [ -3.7326],\n",
       "         [ -5.2393],\n",
       "         [ -1.0571],\n",
       "         [ -0.6817],\n",
       "         [ -6.9080],\n",
       "         [ -5.8684],\n",
       "         [ -8.8353],\n",
       "         [ -3.5536],\n",
       "         [-10.7532],\n",
       "         [-11.3285],\n",
       "         [ -0.7938],\n",
       "         [ -2.9140],\n",
       "         [ -1.4667],\n",
       "         [ -7.4287],\n",
       "         [ -4.6343],\n",
       "         [-10.8067],\n",
       "         [ -1.3602],\n",
       "         [ -8.5227],\n",
       "         [ -4.4132],\n",
       "         [ -3.3595],\n",
       "         [ -9.9071]]),\n",
       " tensor([[ 3.1554e-01,  9.4891e-01, -2.3851e+00],\n",
       "         [ 5.3172e-01,  8.4692e-01, -1.2300e+00],\n",
       "         [-8.1026e-01,  5.8607e-01,  5.0439e+00],\n",
       "         [-7.5552e-01,  6.5513e-01, -5.4201e+00],\n",
       "         [-2.4609e-01,  9.6925e-01, -4.7068e+00],\n",
       "         [-5.0700e-01, -8.6195e-01, -2.2222e+00],\n",
       "         [ 6.6641e-01, -7.4559e-01, -1.0892e+00],\n",
       "         [-4.8721e-01, -8.7329e-01,  1.4989e+00],\n",
       "         [-3.4004e-01, -9.4041e-01,  2.3005e-01],\n",
       "         [ 2.5437e-01,  9.6711e-01,  2.3932e+00],\n",
       "         [ 2.8877e-01,  9.5740e-01,  2.6528e+00],\n",
       "         [ 5.8329e-01,  8.1226e-01,  5.0400e-01],\n",
       "         [ 4.7866e-01,  8.7800e-01, -1.5885e+00],\n",
       "         [-9.5664e-01,  2.9128e-01, -5.9471e+00],\n",
       "         [-5.8631e-01, -8.1009e-01,  2.7019e+00],\n",
       "         [-6.1322e-01, -7.8991e-01, -2.5686e+00],\n",
       "         [ 4.6319e-01, -8.8626e-01, -1.9674e+00],\n",
       "         [-8.4997e-01, -5.2683e-01, -3.6573e+00],\n",
       "         [ 5.2959e-01,  8.4826e-01,  1.2932e+00],\n",
       "         [-3.5107e-01, -9.3635e-01, -2.3526e-01],\n",
       "         [-2.3044e-01,  9.7309e-01,  3.7126e+00],\n",
       "         [ 5.6208e-01,  8.2709e-01, -1.4356e+00],\n",
       "         [-9.0332e-01,  4.2897e-01, -5.4418e+00],\n",
       "         [-7.7629e-01, -6.3037e-01, -3.3734e+00],\n",
       "         [-4.8798e-01, -8.7285e-01,  2.3343e+00],\n",
       "         [ 5.8615e-01,  8.1020e-01, -1.2689e+00],\n",
       "         [ 6.4217e-01,  7.6656e-01, -1.4794e+00],\n",
       "         [-5.6051e-01,  8.2814e-01, -5.2287e+00],\n",
       "         [ 5.4799e-01, -8.3649e-01, -1.6400e+00],\n",
       "         [-9.9961e-01,  2.7917e-02, -5.2638e+00],\n",
       "         [ 6.3846e-01,  7.6965e-01,  3.8522e-01],\n",
       "         [-4.3470e-01, -9.0057e-01, -1.8644e+00],\n",
       "         [-9.6877e-01, -2.4797e-01,  4.8901e+00],\n",
       "         [ 1.1601e-01,  9.9325e-01,  2.8185e+00],\n",
       "         [ 6.5314e-01,  7.5724e-01,  3.3397e-01],\n",
       "         [ 7.0601e-01, -7.0821e-01, -2.5809e-01],\n",
       "         [-4.3317e-01,  9.0131e-01, -5.2340e+00],\n",
       "         [-1.1015e-01, -9.9391e-01, -3.6114e+00],\n",
       "         [ 6.0300e-01,  7.9774e-01,  1.9005e+00],\n",
       "         [-8.3319e-01,  5.5299e-01, -5.6176e+00],\n",
       "         [ 7.6406e-01,  6.4515e-01, -5.4088e-01],\n",
       "         [ 4.6417e-01,  8.8575e-01, -1.9675e+00],\n",
       "         [-8.4611e-01, -5.3301e-01,  3.9648e+00],\n",
       "         [-4.0830e-01, -9.1285e-01, -1.2375e+00],\n",
       "         [-5.0223e-01, -8.6473e-01,  1.7720e+00],\n",
       "         [ 6.2440e-01,  7.8111e-01, -9.6122e-01],\n",
       "         [ 7.4633e-01,  6.6557e-01, -7.4006e-01],\n",
       "         [-6.6090e-01, -7.5048e-01,  2.4675e+00],\n",
       "         [-5.7678e-01, -8.1690e-01,  2.1447e+00],\n",
       "         [-9.7245e-01, -2.3311e-01, -4.7890e+00],\n",
       "         [ 2.0041e-01,  9.7971e-01, -2.8199e+00],\n",
       "         [-8.9121e-01, -4.5358e-01,  4.4041e+00],\n",
       "         [-9.9734e-01, -7.2951e-02,  4.8012e+00],\n",
       "         [ 6.6569e-01,  7.4623e-01, -4.0571e-01],\n",
       "         [-4.6885e-02,  9.9890e-01,  3.2634e+00],\n",
       "         [ 3.6800e-01,  9.2983e-01,  1.9958e+00],\n",
       "         [-3.4255e-01,  9.3950e-01, -4.9076e+00],\n",
       "         [-4.2563e-01,  9.0490e-01,  4.1424e+00],\n",
       "         [-1.0000e+00, -1.7916e-03, -5.6457e+00],\n",
       "         [ 4.3982e-01,  8.9809e-01,  2.1890e+00],\n",
       "         [-9.0962e-01,  4.1545e-01,  5.1685e+00],\n",
       "         [ 1.2653e-01,  9.9196e-01, -3.2303e+00],\n",
       "         [-2.7963e-01, -9.6011e-01, -4.4145e-01],\n",
       "         [-9.8608e-01,  1.6627e-01,  4.9765e+00]]),\n",
       " tensor([[0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0]]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#获取一批数据样本\n",
    "def get_sample():\n",
    "    #从样本池中采样\n",
    "    samples = random.sample(datas, 64)\n",
    "\n",
    "    #[b, 3]\n",
    "    state = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 3)\n",
    "    #[b, 1]\n",
    "    action = torch.LongTensor([i[1] for i in samples]).reshape(-1, 1)\n",
    "    #[b, 1]\n",
    "    reward = torch.FloatTensor([i[2] for i in samples]).reshape(-1, 1)\n",
    "    #[b, 3]\n",
    "    next_state = torch.FloatTensor([i[3] for i in samples]).reshape(-1, 3)\n",
    "    #[b, 1]\n",
    "    over = torch.LongTensor([i[4] for i in samples]).reshape(-1, 1)\n",
    "\n",
    "    return state, action, reward, next_state, over\n",
    "\n",
    "\n",
    "state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "state, action, reward, next_state, over"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.6694],\n",
       "        [ 0.6398],\n",
       "        [ 1.5812],\n",
       "        [ 0.8206],\n",
       "        [ 0.8143],\n",
       "        [ 0.7305],\n",
       "        [ 0.6991],\n",
       "        [ 1.0158],\n",
       "        [ 0.8393],\n",
       "        [ 0.9109],\n",
       "        [ 0.9820],\n",
       "        [ 0.7557],\n",
       "        [ 0.6587],\n",
       "        [ 0.8052],\n",
       "        [ 1.1696],\n",
       "        [ 0.7094],\n",
       "        [ 0.8675],\n",
       "        [ 0.6739],\n",
       "        [ 0.8323],\n",
       "        [ 0.7818],\n",
       "        [ 1.1674],\n",
       "        [ 0.6515],\n",
       "        [ 0.7904],\n",
       "        [ 0.6746],\n",
       "        [ 1.1329],\n",
       "        [ 0.6424],\n",
       "        [ 0.6544],\n",
       "        [ 0.8285],\n",
       "        [ 0.8267],\n",
       "        [ 0.7354],\n",
       "        [ 0.7219],\n",
       "        [ 0.7371],\n",
       "        [ 1.7178],\n",
       "        [ 0.9916],\n",
       "        [ 0.7044],\n",
       "        [ 0.7765],\n",
       "        [ 0.8380],\n",
       "        [ 0.9471],\n",
       "        [ 0.8877],\n",
       "        [ 0.8185],\n",
       "        [ 0.6623],\n",
       "        [ 0.6660],\n",
       "        [ 1.5353],\n",
       "        [ 0.7143],\n",
       "        [ 1.0567],\n",
       "        [ 0.6376],\n",
       "        [ 0.6505],\n",
       "        [ 1.2026],\n",
       "        [ 1.1022],\n",
       "        [ 0.6855],\n",
       "        [-1.5175],\n",
       "        [ 1.6335],\n",
       "        [ 1.6549],\n",
       "        [ 0.6450],\n",
       "        [ 1.0685],\n",
       "        [ 0.8279],\n",
       "        [ 0.8222],\n",
       "        [ 1.2892],\n",
       "        [ 0.7569],\n",
       "        [ 0.8993],\n",
       "        [ 1.6460],\n",
       "        [ 0.7462],\n",
       "        [ 0.7254],\n",
       "        [ 1.6441]], grad_fn=<GatherBackward0>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_value(state, action):\n",
    "    #使用状态计算出动作的logits\n",
    "    #[b, 3] -> [b, 11]\n",
    "    value = model(state)\n",
    "\n",
    "    #根据实际使用的action取出每一个值\n",
    "    #这个值就是模型评估的在该状态下,执行动作的分数\n",
    "    #在执行动作前,显然并不知道会得到的反馈和next_state\n",
    "    #所以这里不能也不需要考虑next_state和reward\n",
    "    #[b, 11] -> [b, 1]\n",
    "    value = value.gather(dim=1, index=action)\n",
    "\n",
    "    return value\n",
    "\n",
    "\n",
    "get_value(state, action)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ -2.0177],\n",
       "        [ -0.7808],\n",
       "        [ -5.7722],\n",
       "        [ -9.4339],\n",
       "        [ -6.0255],\n",
       "        [ -3.4276],\n",
       "        [  0.1287],\n",
       "        [ -4.0666],\n",
       "        [ -3.0063],\n",
       "        [ -0.8561],\n",
       "        [ -0.7995],\n",
       "        [ -0.0449],\n",
       "        [ -1.0915],\n",
       "        [-12.2143],\n",
       "        [ -5.2352],\n",
       "        [ -4.2404],\n",
       "        [ -0.3679],\n",
       "        [ -5.9364],\n",
       "        [ -0.1700],\n",
       "        [ -2.9856],\n",
       "        [ -2.4225],\n",
       "        [ -0.7906],\n",
       "        [-10.7996],\n",
       "        [ -5.2174],\n",
       "        [ -4.5037],\n",
       "        [ -0.6521],\n",
       "        [ -0.5942],\n",
       "        [ -8.0275],\n",
       "        [ -0.1616],\n",
       "        [ -9.9790],\n",
       "        [  0.0729],\n",
       "        [ -3.2294],\n",
       "        [-10.9331],\n",
       "        [ -1.2552],\n",
       "        [  0.0961],\n",
       "        [  0.0499],\n",
       "        [ -7.4233],\n",
       "        [ -2.3419],\n",
       "        [ -0.0581],\n",
       "        [-10.2997],\n",
       "        [  0.0606],\n",
       "        [ -1.3227],\n",
       "        [ -8.4187],\n",
       "        [ -3.0167],\n",
       "        [ -4.2644],\n",
       "        [ -0.4239],\n",
       "        [ -0.0327],\n",
       "        [ -5.8279],\n",
       "        [ -4.8328],\n",
       "        [ -8.1146],\n",
       "        [ -2.8976],\n",
       "        [ -9.2818],\n",
       "        [ -9.7528],\n",
       "        [ -0.1035],\n",
       "        [ -1.7700],\n",
       "        [ -0.5740],\n",
       "        [ -6.6416],\n",
       "        [ -3.2530],\n",
       "        [-10.0177],\n",
       "        [ -0.4138],\n",
       "        [ -6.8615],\n",
       "        [ -3.7434],\n",
       "        [ -2.6702],\n",
       "        [ -8.2853]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_target(reward, next_state, over):\n",
    "    #上面已经把模型认为的状态下执行动作的分数给评估出来了\n",
    "    #下面使用next_state和reward计算真实的分数\n",
    "    #针对一个状态,它到底应该多少分,可以使用以往模型积累的经验评估\n",
    "    #这也是没办法的办法,因为显然没有精确解,这里使用延迟更新的next_model评估\n",
    "\n",
    "    #使用next_state计算下一个状态的分数\n",
    "    #[b, 3] -> [b, 11]\n",
    "    with torch.no_grad():\n",
    "        target = next_model(next_state)\n",
    "\n",
    "    #取所有动作中分数最大的\n",
    "    #[b, 11] -> [b, 1]\n",
    "    target = target.max(dim=1)[0]\n",
    "    target = target.reshape(-1, 1)\n",
    "\n",
    "    #下一个状态的分数乘以一个系数,相当于权重\n",
    "    target *= 0.98\n",
    "\n",
    "    #如果next_state已经游戏结束,则next_state的分数是0\n",
    "    #因为如果下一步已经游戏结束,显然不需要再继续玩下去,也就不需要考虑next_state了.\n",
    "    #[b, 1] * [b, 1] -> [b, 1]\n",
    "    target *= (1 - over)\n",
    "\n",
    "    #加上reward就是最终的分数\n",
    "    #[b, 1] + [b, 1] -> [b, 1]\n",
    "    target += reward\n",
    "\n",
    "    return target\n",
    "\n",
    "\n",
    "get_target(reward, next_state, over)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-991.1090374608216"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "\n",
    "\n",
    "def test(play):\n",
    "    #初始化游戏\n",
    "    state = env.reset()\n",
    "\n",
    "    #记录反馈值的和,这个值越大越好\n",
    "    reward_sum = 0\n",
    "\n",
    "    #玩到游戏结束为止\n",
    "    over = False\n",
    "    while not over:\n",
    "        #根据当前状态得到一个动作\n",
    "        _, action_continuous = get_action(state)\n",
    "\n",
    "        #执行动作,得到反馈\n",
    "        state, reward, over, _ = env.step([action_continuous])\n",
    "        reward_sum += reward\n",
    "\n",
    "        #打印动画\n",
    "        if play and random.random() < 0.2:  #跳帧\n",
    "            display.clear_output(wait=True)\n",
    "            show()\n",
    "\n",
    "    return reward_sum\n",
    "\n",
    "\n",
    "test(play=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "id": "OHoSU6uI-xIt",
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 400 200 0 -1502.0604868258624\n",
      "20 4400 200 0 -1255.668405491158\n",
      "40 5000 200 200 -961.9330424508095\n",
      "60 5000 200 200 -213.30234818476475\n",
      "80 5000 200 200 -332.36341367063903\n",
      "100 5000 200 200 -363.6590379047417\n",
      "120 5000 200 200 -192.13057207138917\n",
      "140 5000 200 200 -154.31394867981035\n",
      "160 5000 200 200 -200.26863291597485\n",
      "180 5000 200 200 -130.13165210086567\n"
     ]
    }
   ],
   "source": [
    "def train():\n",
    "    model.train()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)\n",
    "    loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "    #训练N次\n",
    "    for epoch in range(200):\n",
    "        #更新N条数据\n",
    "        update_count, drop_count = update_data()\n",
    "\n",
    "        #每次更新过数据后,学习N次\n",
    "        for i in range(200):\n",
    "            #采样一批数据\n",
    "            state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "            #计算一批样本的value和target\n",
    "            value = get_value(state, action)\n",
    "            target = get_target(reward, next_state, over)\n",
    "\n",
    "            #更新参数\n",
    "            loss = loss_fn(value, target)\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            #把model的参数复制给next_model\n",
    "            if (i + 1) % 50 == 0:\n",
    "                next_model.load_state_dict(model.state_dict())\n",
    "\n",
    "        if epoch % 20 == 0:\n",
    "            test_result = sum([test(play=False) for _ in range(20)]) / 20\n",
    "            print(epoch, len(datas), update_count, drop_count, test_result)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQEAAAD8CAYAAAB3lxGOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAXWUlEQVR4nO3deZSU9Z3v8fe3it6gIbJ2Wpp9GUcvJtAdZjzROxw1ESZzXEjIwaOR5DAhUS8M8eZM4BrnZuIJ6EQdJ2c0iUuOBBeGcwUlixqCcUQvuywii3QCNNh92QQUu+mmq773j35kKthQ1d1VXcvzeZ1Tp5769fM89S2K/vTv2X6PuTsiEl6RbBcgItmlEBAJOYWASMgpBERCTiEgEnIKAZGQy1gImNlkM9ttZrVmNi9T7yMiXWOZOE/AzKLAu8AXgIPABuBmd9+R9jcTkS7JVE9gIlDr7n9y9xZgCXBDht5LRLqgR4bWOxg4kPD6IPBX55t5wIABPnz48AyVIiIAmzZtOuruA89tz1QIWDttf7bdYWazgFkAQ4cOZePGjRkqRUQAzGx/e+2Z2hw4CAxJeF0F1CfO4O6PuXuNu9cMHPiJcBKRbpKpENgAjDGzEWZWDEwHVmTovUSkCzKyOeDurWb2P4BXgCjwC3d/JxPvJSJdk6l9Arj7b4HfZmr9IpIeOmNQJOQUAiIhpxAQCTmFgEjIKQREQk4hIBJyCgGRkFMIiIScQkAk5BQCIiGnEBAJOYWASMgpBERCTiEgEnIKAZGQUwiIhJxCQCTkFAIiIacQEAk5hYBIyCkEREJOISAScgoBkZBTCIiEnEJAJOQUAiIhpxAQCTmFgEjIKQREQk4hIBJyCgGRkFMIiIScQkAk5BQCIiGnEBAJuaQhYGa/MLPDZrY9oa2fma00sz3Bc9+En803s1oz221m12WqcBFJj1R6Ak8Bk89pmwescvcxwKrgNWZ2KTAduCxY5lEzi6atWhFJu6Qh4O6vA++f03wDsCiYXgTcmNC+xN2b3X0vUAtMTE+pIpIJnd0nUOHuDQDB86CgfTBwIGG+g0HbJ5jZLDPbaGYbjxw50skyRKSr0r1j0Npp8/ZmdPfH3L3G3WsGDhyY5jJEJFWdDYFDZlYJEDwfDtoPAkMS5qsC6jtfnohkWmdDYAUwI5ieAbyY0D7dzErMbAQwBljftRJFJJN6JJvBzJ4DJgEDzOwg8L+B+4ClZjYTqAOmAbj7O2a2FNgBtAJ3unssQ7WLSBokDQF3v/k8P7rmPPP/CPhRV4oSke6jMwZFQk4hIBJyCgGRkFMIiIRc0h2DIoncnVhjIx/t3EnTvn0QidBz5Eh6/cVfECktxay988UklykEJGXuTtO+fRx84glO7dqFnzkDgBUX03vcOIb8/d9TcvHFCoI8o80BSVlzQwN7H3yQD99++2wAAHhLCx9s2sTef/1Xzhw7lsUKpTMUApISb22lYelSTtfVnXeexnff5dALL+AxnR+WTxQCkpIzJ09ycn3yM8CPv/kmsaambqhI0kUhIEm5Ox9s3kyssTH5vOoF5B2FgKSkuaEB4vFslyEZoBCQ5NyJJ+wIlMKiEJCUeEtLtkuQDFEISHLuxBUCBUshIMlpc6CgKQQkKXcndupUSvNGiotBZwzmFYWAJOWxGKcPHEg+I1A6dCiRoqIMVyTppBCQ1Hi7g0Z/gnoC+UchIGkVKSrSBUR5RiEgaWXqCeQdhYCkVaSoSCGQZxQCklamEMg7CgFJymMxPMUdgxbRf6l8o29MkvIzZzp08ZB2DOYXhYAk5a2tKfcEJP8oBCSpeAd7ApJfFAKSlLe24gqBgqUQkKTiLS0p9wS0YzD/6BuTpJrr61MaWoxIhLIRIzJfkKSVQkCS8lgspWsHzIxIWVk3VCTppBCQ9DFru4BI8opCQNJKIZB/FAKSVqaxBPKO7kUoF+TubSMLBc8AETMi7Z0VqM2BvJQ0BMxsCPBL4NNAHHjM3f/NzPoB/wEMB/YBX3X348Ey84GZQAyY4+6vZKR6STt3Jx6Pc/LkSXbv3s3WrVvZ8uqr7NqwgeMtLTTHYtxxySV8vqKi3eWth/6u5JtUvrFW4H+6+1tm1hvYZGYrga8Dq9z9PjObB8wDvmdmlwLTgcuAi4Hfm9lYd9etaXKYu9PY2MjWrVt5/vnnWbVqFfv378fd6RmN0rOlhYuKi/lUaSk9L/SLrusG8k7SEHD3BqAhmP7QzHYCg4EbgEnBbIuA14DvBe1L3L0Z2GtmtcBEYE26i5euc3dOnz7NqlWr+NnPfsYbb7xBSUkJ1dXV3HbbbVRXV1O+ezenX3yRqBkG7W8KSN7qUN/NzIYD44F1QEUQELh7g5kNCmYbDKxNWOxg0CY55ONu/7Zt21iwYAGvvPIKlZWV3HXXXXzlK19h1KhRFAfb9//v8GHqo9EsVyyZknIImFk58Dww190/uMDlou394BNnmpjZLGAWwNChQ1MtQ9KkqamJp556ivvuu49YLMZ3vvMdZs6cSVVVFZGEU39dNx4peCmFgJkV0RYAz7j7sqD5kJlVBr2ASuBw0H4QGJKweBVQf+463f0x4DGAmpoaXafaTdyd999/nx/84AcsWrSIK664gnvvvZcJEybQ4zzb+roFWWFLep6Atf3JfxLY6e4PJfxoBTAjmJ4BvJjQPt3MSsxsBDAGSH5je8k4d+fQoUPccccdLFq0iG9961s899xzfO5znztvAAA0/ulPKa0/WlZGREcH8k4q39jnga8Bb5vZlqDtfwH3AUvNbCZQB0wDcPd3zGwpsIO2Iwt36shA9rk7R48eZc6cOaxatYof/vCHfPvb36a0tDTZgqldPAQUDxpEpGfPNFQr3SmVowNv0P52PsA151nmR8CPulCXpFlTUxPf//73Wbly5dkAuNBf/86wHj10KXEe0jcWArFYjJ///Oc8++yzzJ49m1mzZtGjR4+0jwWoEMhP+sYKnLuzZs0a7r//fqZMmcJ3v/tdiouLMzIYqPXoAQqBvKO9OAXu5MmT3HvvvfTp0+fsc6ZEiorUE8hDCoEC5u4sW7aMtWvX8sADDzB27NiMvp96AvlJ31gBO3bsGI888gjV1dV89atf7dQ6PBZrG1koFTqdOC+pJ1Cg3J2XXnqJPXv28MQTT9CnT59O7QfoUAigG4/kI/UEClRjYyPPPPMMl1xyCddee22nfzk9FoMOhIDkH4VAgdqzZw8bNmzgpptuom/fvp1ej7e2dqgnIPlHIVCA3J1XX30VM2PKlCld6qJ3dHOg0+/jzrFjx4jrJifdTiFQgJqbm1m9ejWjR49m9OjRXVpX6wcf0PrhhynN25WhxVpaWnj44Yc5fPhw8pklrRQCBejUqVO8/fbbVFdX06tXry6tK97YSLypKaV5e44a1en32b9/P7/85S9Zu3atbn7azRQCBaiuro7jx48zYcKEbn3fSLKLkc7D3fnNb35DfX09L7zwAjHtg+hWCoECdODAAdydEd18S7DObg589NFHLF++nHg8zuuvv05DQ0OaK5MLUQgUoKNHj9KjRw8GDBiQ9uP2H7S08FRtLT/evp1t77//Z133zoSAu7N582a2bt0KwHvvvcebb76pTYJupBAoQM3NzRQVFaX9OoEPz5zhnzZv5t937uQ/9u7lO+vXs/bIkbM/j5SUdHidH5/afOrUKQBaW1v51a9+xZkzZ9JWt1yYQqBARaPRswOFdnFFbQ/gvcZG/m/C3vuTZ87wu/r/GjnOOvF+hw8f5qWXXqJXr16YGUVFRaxevZr6+k+MSCcZohAoUB+PJtxVPUeOpGzYMACKIxFKzhl1uE/Cbcc6OrSYu7N69Wr69+/PwoULKSsr44477mDYsGG8/vrr2iToJgqBAmRmnDlzhsYUhwW74Lp69KDP+PEAjCgv53vjxjGgpISSaJSrKyuZOWYM0Da0WPGnP92hdcdiMdyd5557jvHjx+PuTJo0icWLFxOJRBQC3UQXEBWg3r1709rayokTJ3D3Lu0cNDMGTpnCB5s20bRvH1+qqmJC//40tbYyuFcvSqNRrLiYiqlTKerg6cmRSISpU6cSjUbZvHkzkUiEvn37MmzYMC6++GJdjNRN1BMoQJWVlcRisbQdaiseOJChd95J2YgRWCTCxT17MqpPH0qjUSI9e/LpadMY0ImLlCKRyNlxDuvq6igtLaV///6YWcZGP5JPUk+gAA0fPpzi4mJ27drF9ddf3+X1mRm9xo5l9D338P5//icfbttG7PRpyoYMod/f/A29/vIvuzTUeDweZ9euXfTr149BgwYlX0DSSiFQgAYNGsTQoUPZsGED8XicaBpuIWZmFA8YQMXUqVRMnQruZwcR6epf7ObmZrZs2cLo0aO56KKLulyrdIw2BwpQWVkZEydOZNu2bRxJOI6fDmbW9ohEzk53hbuzf/9+amtrufLKK9MSWNIxCoECFI1Gufrqq2loaGDz5s3ZLiep1atX09raylVXXZXtUkJJIVCAzIwrrriCyspKli1bltNn37W0tLB8+XLGjh3LuHHjtDMwCxQCBaqiooLJkyfz8ssvs3fv3pw95v7WW2+xYcMGvvzlL1NeXp7tckJJIVCgotEot9xyC01NTSxevDgnR+xpaWnh8ccfp1evXtx0001/dkt06T76Vy9g48eP57rrrmPx4sW8++67OdUbcHfWrVvHihUruPnmmxk+fHi2SwothUABKyoqYu7cuTQ3N/PjH/+YlpaWbJcEtAXAyZMnWbhwIX379mXWrFk6KpBFCoECZmaMHz+emTNnsmzZsrMDd2RbPB7nySef5M033+Suu+5SLyDLFAIFLhqNMnv2bC6//HLuuecetm3bltXNAnfntdde44EHHuALX/gCt956q44IZJlCoMCZGYMGDeLBBx+kpaWF2bNns2/fvqwEgbuzfft25syZQ0VFBQsWLKC8vFwhkGUKgRAwM6qrq3nwwQfZvXs3t99+O3V1dd0aBO7Ojh07+OY3v8lHH33ET37yE8aMGaMAyAEKgZCIRCLceOONLFy4kE2bNvH1r3+dXbt2dUsQxONx1q1bx4wZM6ivr+fRRx/lyiuvVADkCIVAiESjUW677TYeeughdu3axbRp03j55ZdpbW3NSBi4O83NzSxZsoSbb76ZEydO8PjjjzN58mSdE5BDkn4TZlZqZuvNbKuZvWNm/xy09zOzlWa2J3jum7DMfDOrNbPdZnZdJj+ApO7jMfxuueUWnn76aaLRKF/72te4++67qa+vT+uRg1gsRm1tLbNnz+b222+nqqqKJUuW8MUvflEBkGvc/YIPwIDyYLoIWAf8NfAvwLygfR5wfzB9KbAVKAFGAH8Eohd6j+rqapfuFY/Hfc+ePT5jxgz/1Kc+5Zdffrn/9Kc/9YaGBo/FYp1eb2trq+/du9cXLFjgo0eP9v79+/vcuXP9vffe83g8nsZPIB0FbPT2fsfbazzfA+gJvAX8FbAbqAzaK4HdwfR8YH7CMq8AV1xovQqB7IjH497Y2OjLly/3q666ysvLy/2yyy7zu+++29esWePHjx/3WCzm8Xi83V/gj9tjsZgfPXrU//CHP/icOXN85MiRXl5e7pMnT/bf//733tzcrADIAecLAfMUtgXNLApsAkYDj7j798zshLtflDDPcXfva2b/Dqx196eD9ieBl9z9/5yzzlnALIChQ4dW79+/P5WOi2SAB2fw/e53v+Opp55i7dq1xONxhg8fzmc+8xk++9nPMnLkSPr27UtZWRnuTmNjI8eOHaO2tpYtW7awbds26urqKCsr46qrruIb3/gGkyZNOjuUuGSfmW1y95pPtKcSAgkruQhYDswG3jhPCDwCrDknBH7r7s+fb701NTW+cePGlOuQzHB3mpqa2LlzJytXrmT16tXs2LGDEydO0NraSjQaPTsKcDweJx6PU1RURP/+/Rk3bhyTJk3immuuYdSoUZSUlOiXP8ecLwQ6NLyYu58ws9eAycAhM6t09wYzqwQ+vivFQWBIwmJVgO4kkQfMjJ49ezJhwgTGjx/P3LlzOXnyJAcPHqS+vp4TJ05w+vRpAEpLS+nXrx+DBw+mqqqK3r17a3DQPJU0BMxsIHAmCIAy4FrgfmAFMAO4L3h+MVhkBfCsmT0EXAyMAdZnoHbJkI+HDSstLaW0tJSKigqqq6uzXZZkSCo9gUpgUbBfIAIsdfdfm9kaYKmZzQTqgGkA7v6OmS0FdgCtwJ3urntNi+SoDu0TyBTtExDJvPPtE9BZGyIhpxAQCTmFgEjIKQREQk4hIBJyCgGRkFMIiIScQkAk5BQCIiGnEBAJOYWASMgpBERCTiEgEnIKAZGQUwiIhJxCQCTkFAIiIacQEAk5hYBIyCkEREJOISAScgoBkZBTCIiEnEJAJOQUAiIhpxAQCTmFgEjIKQREQk4hIBJyCgGRkFMIiIScQkAk5BQCIiGnEBAJuZRDwMyiZrbZzH4dvO5nZivNbE/w3Ddh3vlmVmtmu83sukwULiLp0ZGewD8AOxNezwNWufsYYFXwGjO7FJgOXAZMBh41s2h6yhWRdEspBMysCvgS8ERC8w3AomB6EXBjQvsSd292971ALTAxLdWKSNql2hN4GPhHIJ7QVuHuDQDB86CgfTBwIGG+g0GbiOSgpCFgZn8HHHb3TSmu09pp83bWO8vMNprZxiNHjqS4ahFJt1R6Ap8HrjezfcAS4Gozexo4ZGaVAMHz4WD+g8CQhOWrgPpzV+ruj7l7jbvXDBw4sAsfQUS6ImkIuPt8d69y9+G07fB71d1vBVYAM4LZZgAvBtMrgOlmVmJmI4AxwPq0Vy4iadGjC8veByw1s5lAHTANwN3fMbOlwA6gFbjT3WNdrlREMsLcP7G53u1qamp848aN2S5DpKCZ2SZ3rzm3XWcMioScQkAk5BQCIiGnEBAJOYWASMgpBERCTiEgEnIKAZGQUwiIhJxCQCTkFAIiIacQEAk5hYBIyCkEREJOISAScgoBkZBTCIiEnEJAJOQUAiIhpxAQCTmFgEjIKQREQk4hIBJyCgGRkFMIiIScQkAk5BQCIiGnEBAJOYWASMgpBERCTiEgEnIKAZGQUwiIhJxCQCTkFAIiIacQEAk5hYBIyCkERELO3D3bNWBmR4CPgKPZriVFA8ifWiG/6lWtmTPM3Qee25gTIQBgZhvdvSbbdaQin2qF/KpXtXY/bQ6IhJxCQCTkcikEHst2AR2QT7VCftWrWrtZzuwTEJHsyKWegIhkQdZDwMwmm9luM6s1s3nZrgfAzH5hZofNbHtCWz8zW2lme4Lnvgk/mx/Uv9vMruvmWoeY2R/MbKeZvWNm/5Cr9ZpZqZmtN7OtQa3/nKu1Jrx/1Mw2m9mvc73WTnP3rD2AKPBHYCRQDGwFLs1mTUFd/x2YAGxPaPsXYF4wPQ+4P5i+NKi7BBgRfJ5oN9ZaCUwIpnsD7wY15Vy9gAHlwXQRsA7461ysNaHmu4BngV/n8v+Drjyy3ROYCNS6+5/cvQVYAtyQ5Zpw99eB989pvgFYFEwvAm5MaF/i7s3uvheope1zdQt3b3D3t4LpD4GdwOBcrNfbnApeFgUPz8VaAcysCvgS8ERCc07W2hXZDoHBwIGE1weDtlxU4e4N0PaLBwwK2nPmM5jZcGA8bX9hc7LeoHu9BTgMrHT3nK0VeBj4RyCe0JartXZatkPA2mnLt8MVOfEZzKwceB6Y6+4fXGjWdtq6rV53j7n7Z4EqYKKZ/bcLzJ61Ws3s74DD7r4p1UXaacuL/8vZDoGDwJCE11VAfZZqSeaQmVUCBM+Hg/asfwYzK6ItAJ5x92VBc87WC+DuJ4DXgMnkZq2fB643s320baZebWZP52itXZLtENgAjDGzEWZWDEwHVmS5pvNZAcwIpmcALya0TzezEjMbAYwB1ndXUWZmwJPATnd/KJfrNbOBZnZRMF0GXAvsysVa3X2+u1e5+3Da/l++6u635mKtXZbtPZPA39K2R/uPwN3Zrieo6TmgAThDW8LPBPoDq4A9wXO/hPnvDurfDUzp5lqvpK3buQ3YEjz+NhfrBS4HNge1bgf+KWjPuVrPqXsS/3V0IKdr7cxDZwyKhFy2NwdEJMsUAiIhpxAQCTmFgEjIKQREQk4hIBJyCgGRkFMIiITc/wdxZIU2TRAQiAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "-1.0533274946868918"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test(play=True)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第7章-DQN算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
